//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the main Comb to Synth Conversion Pass Implementation.
//
//  High-level Comb Operations
//             |             |
//             v             |
//   +-------------------+   |
//   | and, or, xor, mux |   |
//   +---------+---------+   |
//             |             |
//     +-------+--------+    |
//     v                v    v
//     +-----+         +-----+
//     | AIG |-------->| MIG |
//     +-----+         +-----+
//
//===----------------------------------------------------------------------===//

#include "circt/Conversion/CombToSynth.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Synth/SynthDialect.h"
#include "circt/Dialect/Synth/SynthOps.h"
#include "circt/Support/Naming.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Debug.h"
#include <array>

#define DEBUG_TYPE "comb-to-synth"

namespace circt {
#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
#include "circt/Conversion/Passes.h.inc"
} // namespace circt

using namespace circt;
using namespace comb;

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
  SmallVector<Value> bits;
  comb::extractBits(builder, val, bits);
  return bits;
}

// Construct a mux tree for shift operations. `isLeftShift` controls the
// direction of the shift and determines the order of the padding and the
// extracted bits. The callbacks `getPadding` and `getExtract` provide the
// padding and the extracted bits for each shift amount. `getPadding` may
// return a null Value to represent zero-width (i0) padding, but otherwise
// both callbacks must return a valid value for every shift amount in the
// range [0, maxShiftAmount]. The padding for `maxShiftAmount` is used as the
// out-of-bounds value.
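//
// For example (illustrative), lowering `comb.shl %a, %sh : i4` builds the
// candidate results
//   node[0] = a[3:0]          node[1] = {a[2:0], 0b0}
//   node[2] = {a[1:0], 0b00}  node[3] = {a[0], 0b000}
// selects among them with a mux tree indexed by the bits of %sh, and muxes
// in the out-of-bounds value (all padding bits) when %sh >= 4.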
template <bool isLeftShift>
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
                              Value shiftAmount, int64_t maxShiftAmount,
                              llvm::function_ref<Value(int64_t)> getPadding,
                              llvm::function_ref<Value(int64_t)> getExtract) {
  // Extract individual bits from shift amount
  auto bits = extractBits(rewriter, shiftAmount);

  // Create nodes for each possible shift amount
  SmallVector<Value> nodes;
  nodes.reserve(maxShiftAmount);
  for (int64_t i = 0; i < maxShiftAmount; ++i) {
    Value extract = getExtract(i);
    Value padding = getPadding(i);

    if (!padding) {
      nodes.push_back(extract);
      continue;
    }

    // Concatenate extracted bits with padding
    if (isLeftShift)
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
    else
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
  }

  // Create out-of-bounds value
  auto outOfBoundsValue = getPadding(maxShiftAmount);
  assert(outOfBoundsValue && "outOfBoundsValue must be valid");

  // Construct mux tree for shift operation
  auto result =
      comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

  // Add bounds checking
  auto inBound = rewriter.createOrFold<comb::ICmpOp>(
      loc, ICmpPredicate::ult, shiftAmount,
      hw::ConstantOp::create(rewriter, loc, shiftAmount.getType(),
                             maxShiftAmount));

  return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
                                            outOfBoundsValue);
}

// Return a majority operation if MIG is enabled; otherwise return a majority
// function implemented with Comb operations. In the latter case `carry` sits
// at a slightly smaller logic depth than the other inputs.
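//
// Identities relied upon by the MIG lowering patterns below:
//   maj(a, b, 0) = a & b
//   maj(a, b, 1) = a | b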
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a,
                                    Value b, Value carry,
                                    bool useMajorityInverterOp) {
  if (useMajorityInverterOp) {
    std::array<Value, 3> inputs = {a, b, carry};
    std::array<bool, 3> inverts = {false, false, false};
    return synth::mig::MajorityInverterOp::create(rewriter, loc, inputs,
                                                  inverts);
  }

  // maj(a, b, c) = (c & (a ^ b)) | (a & b)
  auto aXorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b}, true);
  auto andOp =
      comb::AndOp::create(rewriter, loc, ValueRange{carry, aXorB}, true);
  auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b}, true);
  return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB}, true);
}

static Value extractMSB(OpBuilder &builder, Value val) {
  return builder.createOrFold<comb::ExtractOp>(
      val.getLoc(), val, val.getType().getIntOrFloatBitWidth() - 1, 1);
}

static Value extractOtherThanMSB(OpBuilder &builder, Value val) {
  return builder.createOrFold<comb::ExtractOp>(
      val.getLoc(), val, 0, val.getType().getIntOrFloatBitWidth() - 1);
}

namespace {
// A union of Value and IntegerAttr to cleanly handle constant values.
using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
} // namespace

// Return the number of unknown bits (negative if a bit width is unknown) and
// populate `values` with the concatenated constants and unknown values.
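// For example, for concat(0b1 : i1, %x : i2) this populates
// values = [%x, IntegerAttr(0b1)] (LSB side first) and returns 2, the width
// of the unknown value %x.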
static int64_t getNumUnknownBitsAndPopulateValues(
    Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
  // A zero-width value has no unknown bits.
  if (value.getType().isInteger(0))
    return 0;

  // Recursively count unknown bits for concat.
  if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
    int64_t totalUnknownBits = 0;
    for (auto concatInput : llvm::reverse(concat.getInputs())) {
      auto unknownBits =
          getNumUnknownBitsAndPopulateValues(concatInput, values);
      if (unknownBits < 0)
        return unknownBits;
      totalUnknownBits += unknownBits;
    }
    return totalUnknownBits;
  }

  // Constant value is known.
  if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
    values.push_back(constant.getValueAttr());
    return 0;
  }

  // Consider other operations as unknown bits.
  // TODO: We can handle replicate, extract, etc.
  values.push_back(value);
  return hw::getBitWidth(value.getType());
}

// Return a value that substitutes the unknown bits with the mask.
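// For example, with constantOrValues = [%x : i2, IntegerAttr(0b1 : i1)]
// (LSB side first, as produced above) and mask = 0b10, the unknown bits of
// %x are replaced by 0b10, yielding the APInt 0b110.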
static APInt
substituteMaskToValues(size_t width,
                       llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
                       uint32_t mask) {
  uint32_t bitPos = 0, unknownPos = 0;
  APInt result(width, 0);
  for (auto constantOrValue : constantOrValues) {
    int64_t elemWidth;
    if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
      elemWidth = constant.getValue().getBitWidth();
      result.insertBits(constant.getValue(), bitPos);
    } else {
      elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
      assert(elemWidth >= 0 && "unknown bit width");
      assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
      // Create a mask for the unknown bits.
      uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
      result.insertBits(APInt(elemWidth, usedBits), bitPos);
      unknownPos += elemWidth;
    }
    bitPos += elemWidth;
  }

  return result;
}

// Emulate a binary operation with unknown bits using a table lookup.
// This function enumerates all possible combinations of unknown bits and
// emulates the operation for each combination.
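// For example, if the lhs is a constant and rhs = concat(0b1 : i1, %x : i1),
// there is a single unknown bit, so two constant results are precomputed
// (for %x = 0 and %x = 1) and selected by a mux tree driven by %x.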
static LogicalResult emulateBinaryOpForUnknownBits(
    ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
    Operation *op,
    llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
  SmallVector<ConstantOrValue> lhsValues, rhsValues;

  assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
         "op must be a single result binary operation");

  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
  auto loc = op->getLoc();
  auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
  auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);

  // If unknown bit width is detected, abort the lowering.
  if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
    return failure();

  int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
  if (totalUnknownBits > maxEmulationUnknownBits)
    return failure();

  SmallVector<Value> emulatedResults;
  emulatedResults.reserve(1 << totalUnknownBits);

  // Emulate all possible cases.
  DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
  auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
    auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
    auto it = constantPool.find(attr);
    if (it != constantPool.end())
      return it->second;
    auto constant = hw::ConstantOp::create(rewriter, loc, value);
    constantPool[attr] = constant;
    return constant;
  };

  for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
       lhsMask < lhsMaskEnd; ++lhsMask) {
    APInt lhsValue = substituteMaskToValues(width, lhsValues, lhsMask);
    for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
         rhsMask < rhsMaskEnd; ++rhsMask) {
      APInt rhsValue = substituteMaskToValues(width, rhsValues, rhsMask);
      // Emulate.
      emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
    }
  }

  // Create selectors for mux tree.
  SmallVector<Value> selectors;
  selectors.reserve(totalUnknownBits);
  for (auto &concatenatedValues : {rhsValues, lhsValues})
    for (auto valueOrConstant : concatenatedValues) {
      auto value = dyn_cast<Value>(valueOrConstant);
      if (!value)
        continue;
      extractBits(rewriter, value, selectors);
    }

  assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
         "number of selectors must match");
  auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
                                getConstant(APInt::getZero(width)));

  replaceOpAndCopyNamehint(rewriter, op, muxed);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Lower a comb::AndOp operation to synth::aig::AndInverterOp
struct CombAndOpConversion : OpConversionPattern<AndOp> {
  using OpConversionPattern<AndOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AndOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, adaptor.getInputs(), nonInverts);
    return success();
  }
};

/// Lower a comb::OrOp operation to synth::aig::AndInverterOp with invert flags
struct CombOrToAIGConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Implement Or using And and invert flags: a | b = ~(~a & ~b)
    SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
    auto andOp = synth::aig::AndInverterOp::create(
        rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, andOp,
        /*invert=*/true);
    return success();
  }
};

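/// Lower a comb::OrOp to synth::mig::MajorityInverterOp using the identity
/// a | b = maj(a, b, 1).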
struct CombOrToMIGConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();
    SmallVector<Value, 3> inputs(adaptor.getInputs());
    auto one = hw::ConstantOp::create(
        rewriter, op.getLoc(),
        APInt::getAllOnes(hw::getBitWidth(op.getType())));
    inputs.push_back(one);
    std::array<bool, 3> inverts = {false, false, false};
    replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
        rewriter, op, inputs, inverts);
    return success();
  }
};

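/// Lower a synth::aig::AndInverterOp to synth::mig::MajorityInverterOp. A
/// unary and_inv is a (possibly inverting) buffer; a binary one uses the
/// identity a & b = maj(a, b, 0).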
struct AndInverterToMIGConversion
    : OpConversionPattern<synth::aig::AndInverterOp> {
  using OpConversionPattern<synth::aig::AndInverterOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(synth::aig::AndInverterOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() > 2)
      return failure();
    if (op.getNumOperands() == 1) {
      SmallVector<bool, 1> inverts{op.getInverted()[0]};
      replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
          rewriter, op, adaptor.getInputs(), inverts);
      return success();
    }
    SmallVector<Value, 3> inputs(adaptor.getInputs());
    auto zero = hw::ConstantOp::create(
        rewriter, op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType())));
    inputs.push_back(zero);
    SmallVector<bool, 3> inverts(adaptor.getInverted());
    inverts.push_back(false);
    replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
        rewriter, op, inputs, inverts);
    return success();
  }
};

/// Lower a comb::XorOp operation to AIG operations
struct CombXorOpConversion : OpConversionPattern<XorOp> {
  using OpConversionPattern<XorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(XorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();
    // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)

    // (a | b) = ~(~a & ~b)
    // (~a | ~b) = ~(a & b)
    auto inputs = adaptor.getInputs();
    SmallVector<bool> allInverts(inputs.size(), true);
    SmallVector<bool> allNotInverts(inputs.size(), false);

    auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
                                                         inputs, allInverts);
    auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
                                                   inputs, allNotInverts);

    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, notAAndNotB, aAndB,
        /*lhs_invert=*/true,
        /*rhs_invert=*/true);
    return success();
  }
};

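/// Lower a variadic fully-associative operation to a balanced binary tree,
/// e.g. xor(a, b, c, d) -> xor(xor(a, b), xor(c, d)).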
template <typename OpTy>
struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }

  static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
                                       ConversionPatternRewriter &rewriter) {
    Value lhs, rhs;
    switch (operands.size()) {
    case 0:
      llvm_unreachable("cannot be called with empty operand range");
      break;
    case 1:
      return operands[0];
    case 2:
      lhs = operands[0];
      rhs = operands[1];
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    default:
      auto firstHalf = operands.size() / 2;
      lhs =
          lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
      rhs =
          lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    }
  }
};

// Lower comb::MuxOp to AIG operations.
struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
  using OpConversionPattern<MuxOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MuxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value cond = op.getCond();
    auto trueVal = op.getTrueValue();
    auto falseVal = op.getFalseValue();

    if (!op.getType().isInteger()) {
      // If the type of the mux is not integer, bitcast the operands first.
      auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
      trueVal =
          hw::BitcastOp::create(rewriter, op->getLoc(), widthType, trueVal);
      falseVal =
          hw::BitcastOp::create(rewriter, op->getLoc(), widthType, falseVal);
    }

    // Replicate condition if needed
    if (!trueVal.getType().isInteger(1))
      cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
                                       cond);

    // c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
    auto lhs =
        synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond, trueVal);
    auto rhs = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond,
                                                 falseVal, true, false);

    Value result = comb::OrOp::create(rewriter, op.getLoc(), lhs, rhs);
    // Insert the bitcast if the type of the mux is not integer.
    if (result.getType() != op.getType())
      result =
          hw::BitcastOp::create(rewriter, op.getLoc(), op.getType(), result);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Adder Architecture Selection
//===----------------------------------------------------------------------===//

enum AdderArchitecture { RippleCarry, Sklansky, KoggeStone, BrentKung };
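/// Select the adder architecture. The `synth.test.arch` string attribute
/// (e.g. `{synth.test.arch = "KOGGE-STONE"}`) overrides the selection;
/// otherwise a width-based heuristic is used.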
AdderArchitecture determineAdderArch(Operation *op, int64_t width) {
  auto strAttr = op->getAttrOfType<StringAttr>("synth.test.arch");
  if (strAttr) {
    return llvm::StringSwitch<AdderArchitecture>(strAttr.getValue())
        .Case("SKLANSKY", Sklansky)
        .Case("KOGGE-STONE", KoggeStone)
        .Case("BRENT-KUNG", BrentKung)
        .Case("RIPPLE-CARRY", RippleCarry);
  }
  // Determine using width as a heuristic.
  // TODO: Perform a more thorough analysis to motivate the choices or
  // implement an adder synthesis algorithm to construct an optimal adder
  // under the given timing constraints - see the work of Zimmermann

  // For very small adders, overhead of a parallel prefix adder is likely not
  // worth it.
  if (width < 8)
    return AdderArchitecture::RippleCarry;

  // Sklansky is a good compromise for high performance, but it has high
  // fanout, which may lead to wiring congestion for very large adders.
  if (width <= 32)
    return AdderArchitecture::Sklansky;

  // Kogge-Stone uses more area than Sklansky but has lower fanout, so it may
  // be preferable for larger adders.
  return AdderArchitecture::KoggeStone;
}

//===----------------------------------------------------------------------===//
// Parallel Prefix Tree
//===----------------------------------------------------------------------===//

// Implement the Kogge-Stone parallel prefix tree
// Described in https://en.wikipedia.org/wiki/Kogge%E2%80%93Stone_adder
// Slightly better delay than Brent-Kung, but more area.
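// Each stage combines (p_i, g_i) with the pair at distance `stride`:
//   g_i' = g_i OR (p_i AND g_j),  p_i' = p_i AND p_j,  where j = i - stride
// so after ceil(log2(width)) stages, g_i is the group generate (the carry
// out) of bits [i:0].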
void lowerKoggeStonePrefixTree(OpBuilder &builder, Location loc,
                               SmallVector<Value> &pPrefix,
                               SmallVector<Value> &gPrefix) {

  auto width = static_cast<int64_t>(pPrefix.size());
  assert(width == static_cast<int64_t>(gPrefix.size()));
  SmallVector<Value> pPrefixNew = pPrefix;
  SmallVector<Value> gPrefixNew = gPrefix;

  // Kogge-Stone parallel prefix computation
  for (int64_t stride = 1; stride < width; stride *= 2) {

    for (int64_t i = stride; i < width; ++i) {
      int64_t j = i - stride;

      // Group generate: g_i OR (p_i AND g_j)
      Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
      gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);

      // Group propagate: p_i AND p_j
      pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
    }

    pPrefix = pPrefixNew;
    gPrefix = gPrefixNew;
  }

  LLVM_DEBUG({
    int64_t stage = 0;
    for (int64_t stride = 1; stride < width; stride *= 2) {
      llvm::dbgs()
          << "--------------------------------------- Kogge-Stone Stage "
          << stage << "\n";
      for (int64_t i = stride; i < width; ++i) {
        int64_t j = i - stride;
        // Group generate: g_i OR (p_i AND g_j)
        llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                     << " OR (P" << i << stage << " AND G" << j << stage
                     << ")\n";

        // Group propagate: p_i AND p_j
        llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                     << " AND P" << j << stage << "\n";
      }
      ++stage;
    }
  });
}

// Implement the Sklansky parallel prefix tree
// High fan-out, low depth, low area
void lowerSklanskyPrefixTree(OpBuilder &builder, Location loc,
                             SmallVector<Value> &pPrefix,
                             SmallVector<Value> &gPrefix) {
  auto width = static_cast<int64_t>(pPrefix.size());
  assert(width == static_cast<int64_t>(gPrefix.size()));
  SmallVector<Value> pPrefixNew = pPrefix;
  SmallVector<Value> gPrefixNew = gPrefix;
  for (int64_t stride = 1; stride < width; stride *= 2) {
    for (int64_t i = stride; i < width; i += 2 * stride) {
      for (int64_t k = 0; k < stride && i + k < width; ++k) {
        int64_t idx = i + k;
        int64_t j = i - 1;

        // Group generate: g_idx OR (p_idx AND g_j)
        Value andPG =
            comb::AndOp::create(builder, loc, pPrefix[idx], gPrefix[j]);
        gPrefixNew[idx] = comb::OrOp::create(builder, loc, gPrefix[idx], andPG);

        // Group propagate: p_idx AND p_j
        pPrefixNew[idx] =
            comb::AndOp::create(builder, loc, pPrefix[idx], pPrefix[j]);
      }
    }

    pPrefix = pPrefixNew;
    gPrefix = gPrefixNew;
  }

  LLVM_DEBUG({
    int64_t stage = 0;
    for (int64_t stride = 1; stride < width; stride *= 2) {
      llvm::dbgs() << "--------------------------------------- Sklansky Stage "
                   << stage << "\n";
      for (int64_t i = stride; i < width; i += 2 * stride) {
        for (int64_t k = 0; k < stride && i + k < width; ++k) {
          int64_t idx = i + k;
          int64_t j = i - 1;
          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << idx << stage + 1 << " = G" << idx << stage
                       << " OR (P" << idx << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << idx << stage + 1 << " = P" << idx << stage
                       << " AND P" << j << stage << "\n";
        }
      }
      ++stage;
    }
  });
}

// Implement the Brent-Kung parallel prefix tree
// Described in https://en.wikipedia.org/wiki/Brent%E2%80%93Kung_adder
// Slightly worse delay than Kogge-Stone, but less area.
void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
                              SmallVector<Value> &pPrefix,
                              SmallVector<Value> &gPrefix) {
  auto width = static_cast<int64_t>(pPrefix.size());
  assert(width == static_cast<int64_t>(gPrefix.size()));
  SmallVector<Value> pPrefixNew = pPrefix;
  SmallVector<Value> gPrefixNew = gPrefix;
  // Brent-Kung parallel prefix computation
  // Forward phase
  int64_t stride;
  for (stride = 1; stride < width; stride *= 2) {
    for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
      int64_t j = i - stride;

      // Group generate: g_i OR (p_i AND g_j)
      Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
      gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);

      // Group propagate: p_i AND p_j
      pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
    }
    pPrefix = pPrefixNew;
    gPrefix = gPrefixNew;
  }

  // Backward phase
  for (; stride > 0; stride /= 2) {
    for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
      int64_t j = i - stride;

      // Group generate: g_i OR (p_i AND g_j)
      Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
      gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);

      // Group propagate: p_i AND p_j
      pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
    }
    pPrefix = pPrefixNew;
    gPrefix = gPrefixNew;
  }

  LLVM_DEBUG({
    int64_t stage = 0;
    for (stride = 1; stride < width; stride *= 2) {
      llvm::dbgs() << "--------------------------------------- Brent-Kung FW "
                   << stage << " : Stride " << stride << "\n";
      for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                     << " OR (P" << i << stage << " AND G" << j << stage
                     << ")\n";

        // Group propagate: p_i AND p_j
        llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                     << " AND P" << j << stage << "\n";
      }
      ++stage;
    }

    for (; stride > 0; stride /= 2) {
      if (stride * 3 - 1 < width)
        llvm::dbgs() << "--------------------------------------- Brent-Kung BW "
                     << stage << " : Stride " << stride << "\n";

      for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                     << " OR (P" << i << stage << " AND G" << j << stage
                     << ")\n";

        // Group propagate: p_i AND p_j
        llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                     << " AND P" << j << stage << "\n";
      }
      --stage;
    }
  });
}

// TODO: Generalize to other parallel prefix trees.
class LazyKoggeStonePrefixTree {
public:
  LazyKoggeStonePrefixTree(OpBuilder &builder, Location loc, int64_t width,
                           ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix)
      : builder(builder), loc(loc), width(width) {
    assert(width > 0 && "width must be positive");
    for (int64_t i = 0; i < width; ++i)
      prefixCache[{0, i}] = {pPrefix[i], gPrefix[i]};
  }

  // Get the final group and propagate values for bit i.
  std::pair<Value, Value> getFinal(int64_t i) {
    assert(i >= 0 && i < width && "i out of bounds");
    // Final level is ceil(log2(width)) in Kogge-Stone.
    return getGroupAndPropagate(llvm::Log2_64_Ceil(width), i);
  }

private:
  // Recursively get the group and propagate values for bit i at level `level`.
  // Level 0 is the initial level with the input propagate and generate values.
  // Level n computes the group and propagate values for a stride of 2^(n-1).
  // Uses memoization to cache intermediate results.
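  // For example, getGroupAndPropagate(2, 3) combines the level-1 results for
  // bits 3 and 1 (stride 2), each of which combines two level-0 inputs
  // (stride 1); only the nodes on the queried paths are materialized.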
  std::pair<Value, Value> getGroupAndPropagate(int64_t level, int64_t i);
  OpBuilder &builder;
  Location loc;
  int64_t width;
  DenseMap<std::pair<int64_t, int64_t>, std::pair<Value, Value>> prefixCache;
};

std::pair<Value, Value>
LazyKoggeStonePrefixTree::getGroupAndPropagate(int64_t level, int64_t i) {
  assert(i < width && "i out of bounds");
  auto key = std::make_pair(level, i);
  auto it = prefixCache.find(key);
  if (it != prefixCache.end())
    return it->second;

  assert(level > 0 && "If the level is 0, we should have hit the cache");

  int64_t previousStride = 1ULL << (level - 1);
  if (i < previousStride) {
    // No dependency, just copy from the previous level.
    auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
    prefixCache[key] = {propagateI, generateI};
    return prefixCache[key];
  }
  // Get the dependency index.
  int64_t j = i - previousStride;
  auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
  auto [propagateJ, generateJ] = getGroupAndPropagate(level - 1, j);
  // Group generate: g_i OR (p_i AND g_j)
  Value andPG = comb::AndOp::create(builder, loc, propagateI, generateJ);
  Value newGenerate = comb::OrOp::create(builder, loc, generateI, andPG);
  // Group propagate: p_i AND p_j
  Value newPropagate =
      comb::AndOp::create(builder, loc, propagateI, propagateJ);
  prefixCache[key] = {newPropagate, newGenerate};
  return prefixCache[key];
}

template <bool lowerToMIG>
struct CombAddOpConversion : OpConversionPattern<AddOp> {
  using OpConversionPattern<AddOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputs = adaptor.getInputs();
    // Lower only when there are two inputs.
    // Variadic operands must be lowered in a different pattern.
    if (inputs.size() != 2)
      return failure();

    auto width = op.getType().getIntOrFloatBitWidth();
    // Skip a zero width value.
    if (width == 0) {
      replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                        op.getType(), 0);
      return success();
    }

    // Select the adder architecture (attribute override or width heuristic).
    auto arch = determineAdderArch(op, width);
    if (arch == AdderArchitecture::RippleCarry)
      return lowerRippleCarryAdder(op, inputs, rewriter);
    return lowerParallelPrefixAdder(op, inputs, rewriter);
  }

  // Implement a basic ripple-carry adder for small bitwidths.
  LogicalResult
  lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
                        ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    // Implement a naive Ripple-carry full adder.
    Value carry;

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);
    SmallVector<Value> results;
    results.resize(width);
    for (int64_t i = 0; i < width; ++i) {
      SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
      if (carry)
        xorOperands.push_back(carry);

      // sum[i] = xor(carry[i-1], a[i], b[i])
      // NOTE: The result is stored in reverse order.
      results[width - i - 1] =
          comb::XorOp::create(rewriter, op.getLoc(), xorOperands, true);

      // If this is the last bit, we are done.
      if (i == width - 1)
        break;

      // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
      if (!carry) {
        // This is the first bit, so the carry out is simply a[i] & b[i].
        carry = comb::AndOp::create(rewriter, op.getLoc(),
                                    ValueRange{aBits[i], bBits[i]}, true);
        continue;
      }

      carry = createMajorityFunction(rewriter, op.getLoc(), aBits[i], bBits[i],
                                     carry, lowerToMIG);
    }
    LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width "
                            << width << "\n");

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
    return success();
  }

  // Implement a parallel prefix adder using a Sklansky, Kogge-Stone, or
  // Brent-Kung tree. This introduces unused intermediate signals for the
  // carry bits, but they are removed by later passes.
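  //
  // As a small worked example: for a = 0b0110 and b = 0b0011,
  //   p = a ^ b = 0b0101,  g = a & b = 0b0010,
  // the prefix tree yields group generates G = 0b0110 (G_i is the carry out
  // of bit i), and the sum is
  //   sum_0 = p_0 = 1,  sum_i = p_i ^ G_{i-1}  =>  0b1001 = 9 = 6 + 3.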
  LogicalResult
  lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
                           ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);

    // Construct propagate (p) and generate (g) signals
    SmallVector<Value> p, g;
    p.reserve(width);
    g.reserve(width);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      // p_i = a_i XOR b_i
      p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
      // g_i = a_i AND b_i
      g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
    }

    LLVM_DEBUG({
      llvm::dbgs() << "Lower comb.add to Parallel-Prefix of width " << width
                   << "\n--------------------------------------- Init\n";

      for (int64_t i = 0; i < width; ++i) {
        // p_i = a_i XOR b_i
        llvm::dbgs() << "P0" << i << " = A" << i << " XOR B" << i << "\n";
        // g_i = a_i AND b_i
        llvm::dbgs() << "G0" << i << " = A" << i << " AND B" << i << "\n";
      }
    });

    // Create copies of p and g for the prefix computation
    SmallVector<Value> pPrefix = p;
    SmallVector<Value> gPrefix = g;

    // Select the adder architecture (attribute override or width heuristic).
    auto arch = determineAdderArch(op, width);

    switch (arch) {
    case AdderArchitecture::RippleCarry:
      llvm_unreachable("Ripple-Carry should be handled separately");
      break;
    case AdderArchitecture::Sklansky:
      lowerSklanskyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
      break;
    case AdderArchitecture::KoggeStone:
      lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
      break;
    case AdderArchitecture::BrentKung:
      lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
      break;
    }

    // Generate result sum bits
    // NOTE: The result is stored in reverse order.
    SmallVector<Value> results;
    results.resize(width);
    // Sum bit 0 is just p[0] since carry_in = 0
    results[width - 1] = p[0];

    // For remaining bits, sum_i = p_i XOR g_(i-1)
    // The carry into position i is the group generate from position i-1
    for (int64_t i = 1; i < width; ++i)
      results[width - 1 - i] =
          comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);

    LLVM_DEBUG({
      llvm::dbgs() << "--------------------------------------- Completion\n"
                   << "RES0 = P0\n";
      for (int64_t i = 1; i < width; ++i)
        llvm::dbgs() << "RES" << i << " = P" << i << " XOR G" << i - 1 << "\n";
    });

    return success();
  }
};

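/// Lower comb::MulOp by forming the partial-product array, compressing it to
/// two addends with datapath::CompressorTree (a Wallace-tree style
/// reduction), and summing the addends with a final carry-propagate comb.add.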
struct CombMulOpConversion : OpConversionPattern<MulOp> {
  using OpConversionPattern<MulOp>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
  LogicalResult
  matchAndRewrite(MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getInputs().size() != 2)
      return failure();

    Location loc = op.getLoc();
    Value a = adaptor.getInputs()[0];
    Value b = adaptor.getInputs()[1];
    unsigned width = op.getType().getIntOrFloatBitWidth();

    // Skip a zero width value.
    if (width == 0) {
      rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
      return success();
    }

    // Extract individual bits from operands
    SmallVector<Value> aBits = extractBits(rewriter, a);
    SmallVector<Value> bBits = extractBits(rewriter, b);

    auto falseValue = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));

    // Generate partial products
    SmallVector<SmallVector<Value>> partialProducts;
    partialProducts.reserve(width);
    for (unsigned i = 0; i < width; ++i) {
      SmallVector<Value> row(i, falseValue);
      row.reserve(width);
      // Generate partial product bits
      for (unsigned j = 0; i + j < width; ++j)
        row.push_back(
            rewriter.createOrFold<comb::AndOp>(loc, aBits[j], bBits[i]));

      partialProducts.push_back(row);
    }

    // If the width is 1, we are done.
    if (width == 1) {
      rewriter.replaceOp(op, partialProducts[0][0]);
      return success();
    }

    // Wallace tree reduction - reduce to two addends.
    datapath::CompressorTree comp(width, partialProducts, loc);
    auto addends = comp.compressToHeight(rewriter, 2);

    // Sum the two addends using a carry-propagate adder
    auto newAdd = comb::AddOp::create(rewriter, loc, addends, true);
    replaceOpAndCopyNamehint(rewriter, op, newAdd);
    return success();
  }
};

template <typename OpTy>
struct DivModOpConversionBase : OpConversionPattern<OpTy> {
  DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
      : OpConversionPattern<OpTy>(context),
        maxEmulationUnknownBits(maxEmulationUnknownBits) {
    assert(maxEmulationUnknownBits < 32 &&
           "maxEmulationUnknownBits must be less than 32");
  }
  const int64_t maxEmulationUnknownBits;
};

struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
  using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(DivUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the divisor is a power of two.
    if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
      return success();

    // When the divisor is not a power of two and the number of unknown bits
    // is small, create a mux tree that emulates all possible cases.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.udiv(rhs);
        });
  }
};

struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
  using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the divisor is a power of two.
    if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
      return success();

    // When the divisor is not a power of two and the number of unknown bits
    // is small, create a mux tree that emulates all possible cases.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.urem(rhs);
        });
  }
};

struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
  using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(DivSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently only lower with emulation.
    // TODO: Implement a signed division lowering at least for power of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.sdiv(rhs);
        });
  }
};

struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
  using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently only lower with emulation.
    // TODO: Implement a signed modulus lowering at least for power of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.srem(rhs);
        });
  }
};

struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
  using OpConversionPattern<ICmpOp>::OpConversionPattern;

  // Simple comparator for small bit widths
  static Value constructRippleCarry(Location loc, Value a, Value b,
                                    bool includeEq,
                                    ConversionPatternRewriter &rewriter) {
    // Construct the following unsigned comparison expressions:
    // a <= b ==> (~a[n] & b[n]) | ((a[n] == b[n]) & (a[n-1:0] <= b[n-1:0]))
    // a <  b ==> (~a[n] & b[n]) | ((a[n] == b[n]) & (a[n-1:0] <  b[n-1:0]))
    auto aBits = extractBits(rewriter, a);
    auto bBits = extractBits(rewriter, b);
    Value acc = hw::ConstantOp::create(rewriter, loc, APInt(1, includeEq));

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      auto aBitXorBBit =
          rewriter.createOrFold<comb::XorOp>(loc, aBit, bBit, true);
      auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
          loc, aBitXorBBit, true);
      auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
          loc, aBit, bBit, true, false);

      auto eqAndAcc = rewriter.createOrFold<comb::AndOp>(
          loc, ValueRange{aEqualB, acc}, true);
      acc = rewriter.createOrFold<comb::OrOp>(loc, pred, eqAndAcc, true);
    }
    return acc;
  }

  // Compute prefix comparison using parallel prefix algorithm
  // Note: This generates all intermediate prefix values even though we only
  // need the final result. Optimizing this to skip intermediate computations
  // is non-trivial because each iteration depends on results from previous
  // iterations. We rely on DCE passes to remove unused operations.
  // TODO: Lazily compute only the required prefix values. Kogge-Stone is
  // already computed lazily via LazyKoggeStonePrefixTree above; the other
  // architectures could be optimized similarly.
  static Value computePrefixComparison(ConversionPatternRewriter &rewriter,
                                       Location loc, SmallVector<Value> pPrefix,
                                       SmallVector<Value> gPrefix,
                                       bool includeEq, AdderArchitecture arch) {
    auto width = pPrefix.size();
    Value finalGroup, finalPropagate;
    // Apply the appropriate prefix tree algorithm
    switch (arch) {
    case AdderArchitecture::RippleCarry:
      llvm_unreachable("Ripple-Carry should be handled separately");
      break;
    case AdderArchitecture::Sklansky: {
      lowerSklanskyPrefixTree(rewriter, loc, pPrefix, gPrefix);
      finalGroup = gPrefix[width - 1];
      finalPropagate = pPrefix[width - 1];
      break;
    }
    case AdderArchitecture::KoggeStone:
      // Use lazy Kogge-Stone implementation to avoid computing all
      // intermediate prefix values.
      std::tie(finalPropagate, finalGroup) =
          LazyKoggeStonePrefixTree(rewriter, loc, width, pPrefix, gPrefix)
              .getFinal(width - 1);
      break;
    case AdderArchitecture::BrentKung: {
      lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
      finalGroup = gPrefix[width - 1];
      finalPropagate = pPrefix[width - 1];
      break;
    }
    }

    // Final result: `finalGroup` gives us "a < b"
    if (includeEq) {
      // a <= b iff (a < b) OR (a == b)
      // a == b iff `finalPropagate` (all bits are equal)
      return comb::OrOp::create(rewriter, loc, finalGroup, finalPropagate);
    }
    // a < b iff `finalGroup`
    return finalGroup;
  }

  // Construct an unsigned comparator using either a ripple-carry or a
  // parallel-prefix architecture. The comparison reuses the parallel prefix
  // tree machinery, so the `AdderArchitecture` enum selects the architecture.
  static Value constructUnsignedCompare(Operation *op, Location loc, Value a,
                                        Value b, bool isLess, bool includeEq,
                                        ConversionPatternRewriter &rewriter) {
    // Canonicalize to a less-than(-or-equal) comparison by swapping operands.
    if (!isLess)
      std::swap(a, b);
    auto width = a.getType().getIntOrFloatBitWidth();

    // Select the adder architecture (attribute override or width heuristic).
    auto arch = determineAdderArch(op, width);
    if (arch == AdderArchitecture::RippleCarry)
      return constructRippleCarry(loc, a, b, includeEq, rewriter);

    // For larger widths, use parallel prefix tree
    auto aBits = extractBits(rewriter, a);
    auto bBits = extractBits(rewriter, b);

    // For comparison, we compute:
    // - Equal bits: eq_i = ~(a_i ^ b_i)
    // - Greater bits: gt_i = ~a_i & b_i (a_i < b_i)
    // - Propagate: p_i = eq_i (equality propagates)
    // - Generate: g_i = gt_i (greater-than generates)
    SmallVector<Value> eq, gt;
    eq.reserve(width);
    gt.reserve(width);

    auto one =
        hw::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(1), 1);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      // eq_i = ~(a_i ^ b_i) = a_i == b_i
      auto xorBit = comb::XorOp::create(rewriter, loc, aBit, bBit);
      eq.push_back(comb::XorOp::create(rewriter, loc, xorBit, one));

      // gt_i = ~a_i & b_i = a_i < b_i
      auto notA = comb::XorOp::create(rewriter, loc, aBit, one);
      gt.push_back(comb::AndOp::create(rewriter, loc, notA, bBit));
    }

    return computePrefixComparison(rewriter, loc, std::move(eq), std::move(gt),
                                   includeEq, arch);
  }

  LogicalResult
  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();

    switch (op.getPredicate()) {
    default:
      return failure();

    case ICmpPredicate::eq:
    case ICmpPredicate::ceq: {
      // a == b  ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      auto xorBits = extractBits(rewriter, xorOp);
      SmallVector<bool> allInverts(xorBits.size(), true);
      replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
          rewriter, op, xorBits, allInverts);
      return success();
    }

    case ICmpPredicate::ne:
    case ICmpPredicate::cne: {
      // a != b  ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
          rewriter, op, extractBits(rewriter, xorOp), true);
      return success();
    }

    case ICmpPredicate::uge:
    case ICmpPredicate::ugt:
    case ICmpPredicate::ule:
    case ICmpPredicate::ult: {
      bool isLess = op.getPredicate() == ICmpPredicate::ult ||
                    op.getPredicate() == ICmpPredicate::ule;
      bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
                       op.getPredicate() == ICmpPredicate::ule;
      replaceOpAndCopyNamehint(rewriter, op,
                               constructUnsignedCompare(op, op.getLoc(), lhs,
                                                        rhs, isLess, includeEq,
                                                        rewriter));
      return success();
    }
    case ICmpPredicate::slt:
    case ICmpPredicate::sle:
    case ICmpPredicate::sgt:
    case ICmpPredicate::sge: {
      if (lhs.getType().getIntOrFloatBitWidth() == 0)
        return rewriter.notifyMatchFailure(
            op.getLoc(), "i0 signed comparison is unsupported");
      bool isLess = op.getPredicate() == ICmpPredicate::slt ||
                    op.getPredicate() == ICmpPredicate::sle;
      bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
                       op.getPredicate() == ICmpPredicate::sle;

      // Get a sign bit
      auto signA = extractMSB(rewriter, lhs);
      auto signB = extractMSB(rewriter, rhs);
      auto aRest = extractOtherThanMSB(rewriter, lhs);
      auto bRest = extractOtherThanMSB(rewriter, rhs);

      // Compare the remaining bits (all bits except the sign). For operands
      // with equal signs, two's-complement ordering matches the unsigned
      // ordering of these bits.
      auto sameSignResult = constructUnsignedCompare(
          op, op.getLoc(), aRest, bRest, isLess, includeEq, rewriter);

      // XOR of signs: true if signs are different
      auto signsDiffer =
          comb::XorOp::create(rewriter, op.getLoc(), signA, signB);

      // Result when signs are different
      Value diffSignResult = isLess ? signA : signB;

      // Final result: choose based on whether signs differ
      replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
          rewriter, op, signsDiffer, diffSignResult, sameSignResult);
      return success();
    }
    }
  }
};

struct CombParityOpConversion : OpConversionPattern<ParityOp> {
  using OpConversionPattern<ParityOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Parity is the XOR of all bits.
    replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
        rewriter, op, extractBits(rewriter, adaptor.getInput()), true);
    return success();
  }
};

struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
  using OpConversionPattern<comb::ShlOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/true>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // Padding is 0 for left shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the LSB side.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
                                                        width - index);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
  using OpConversionPattern<comb::ShrUOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // Padding is 0 for right shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the MSB side.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
  using OpConversionPattern<comb::ShrSOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    if (width == 0)
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "i0 signed shift is unsupported");
    auto lhs = adaptor.getLhs();
    // Get the sign bit.
    auto sign =
        rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);

    // NOTE: The max shift amount is width - 1 because the sign bit is
    // already shifted out.
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
        /*getPadding=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
                                                          index + 1);
        },
        /*getExtract=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index - 1);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Convert Comb to Synth pass
//===----------------------------------------------------------------------===//

namespace {
struct ConvertCombToSynthPass
    : public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
  void runOnOperation() override;
  using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
};
} // namespace

static void
populateCombToSynthConversionPatterns(RewritePatternSet &patterns,
                                      uint32_t maxEmulationUnknownBits,
                                      bool lowerToMIG) {
  patterns.add<
      // Bitwise Logical Ops
      CombAndOpConversion, CombXorOpConversion, CombMuxOpConversion,
      CombParityOpConversion,
      // Arithmetic Ops
      CombMulOpConversion, CombICmpOpConversion,
      // Shift Ops
      CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
      // Variadic ops that must be lowered to binary operations
      CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
      CombLowerVariadicOp<MulOp>>(patterns.getContext());

  patterns.add(comb::convertSubToAdd);

  if (lowerToMIG) {
    patterns.add<CombOrToMIGConversion, CombLowerVariadicOp<OrOp>,
                 AndInverterToMIGConversion,
                 circt::synth::AndInverterVariadicOpConversion,
                 CombAddOpConversion</*lowerToMIG=*/true>>(
        patterns.getContext());
  } else {
    patterns.add<CombOrToAIGConversion,
                 CombAddOpConversion</*lowerToMIG=*/false>>(
        patterns.getContext());
  }

  // Add div/mod patterns with a threshold given by the pass option.
  patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
               CombModSOpConversion>(patterns.getContext(),
                                     maxEmulationUnknownBits);
}

void ConvertCombToSynthPass::runOnOperation() {
  ConversionTarget target(getContext());

  // Comb is source dialect.
  target.addIllegalDialect<comb::CombDialect>();
  // Keep data movement operations like Extract, Concat and Replicate.
  target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
                    hw::BitcastOp, hw::ConstantOp>();

  // Treat array operations as illegal. Strictly speaking, array operations
  // other than an array get with a non-constant index are legal in AIG, but
  // array types prevent a bunch of optimizations, so just lower them to
  // integer operations. The HWAggregateToComb pass must be run before this
  // pass.
  target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
                      hw::AggregateConstantOp>();

  target.addLegalDialect<synth::SynthDialect>();

  if (targetIR == CombToSynthTargetIR::AIG) {
    // When targeting AIG, MIG operations are illegal.
    target.addIllegalOp<synth::mig::MajorityInverterOp>();
  } else if (targetIR == CombToSynthTargetIR::MIG) {
    target.addIllegalOp<synth::aig::AndInverterOp>();
  }

  // If additional legal ops are specified, add them to the target.
  if (!additionalLegalOps.empty())
    for (const auto &opName : additionalLegalOps)
      target.addLegalOp(OperationName(opName, &getContext()));

  RewritePatternSet patterns(&getContext());
  populateCombToSynthConversionPatterns(patterns, maxEmulationUnknownBits,
                                        targetIR == CombToSynthTargetIR::MIG);

  if (failed(mlir::applyPartialConversion(getOperation(), target,
                                          std::move(patterns))))
    return signalPassFailure();
}
